/**
* @file op_compiler.cpp
*
* Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
*/

#include "acl/acl_op_compiler.h"
#include "common/log_inner.h"
#include "compile/op_compile_processor.h"
#include "framework/common/ge_format_util.h"
#include "framework/common/profiling_definitions.h"
#include "op_executor.h"
#include "model/acl_resource_manager.h"
#include "toolchain/profiling_manager.h"
#include "toolchain/resource_statistics.h"
#include "types/acl_op.h"
#include "types/tensor_desc_internal.h"
#include "utils/array_utils.h"
#include "framework/runtime/subscriber/global_profiler.h"
#include "single_op/compile/op_compile_processor.h"

namespace {
// Buffer size reported by aclGetCompileoptSize for ACL_OP_JIT_COMPILE.
constexpr size_t COMPILE_OPT_SIZE = 256U;
// Translates the public aclCompileOpt enum values into the option key strings that are handed to
// OpCompileProcessor::SetCompileOpt / GetCompileOpt. An option that is absent from this table is
// treated as unknown by the accessors in this file.
// NOTE: the map is intentionally left non-const because lookups in this file use operator[].
std::map<aclCompileOpt, std::string> compileOptMap = {{ACL_PRECISION_MODE, ge::PRECISION_MODE},
                                                      {ACL_PRECISION_MODE_V2, ge::PRECISION_MODE_V2},
                                                      {ACL_AICORE_NUM, ge::AICORE_NUM},
                                                      {ACL_OP_SELECT_IMPL_MODE, ge::OP_SELECT_IMPL_MODE},
                                                      {ACL_OPTYPELIST_FOR_IMPLMODE, ge::OPTYPELIST_FOR_IMPLMODE},
                                                      {ACL_OP_DEBUG_LEVEL, ge::OP_DEBUG_LEVEL},
                                                      {ACL_DEBUG_DIR, ge::DEBUG_DIR},
                                                      {ACL_OP_COMPILER_CACHE_MODE, ge::OP_COMPILER_CACHE_MODE},
                                                      {ACL_OP_COMPILER_CACHE_DIR, ge::OP_COMPILER_CACHE_DIR},
                                                      {ACL_OP_PERFORMANCE_MODE, ge::PERFORMANCE_MODE},
                                                      {ACL_OP_JIT_COMPILE, "ge.jit_compile"},
                                                      {ACL_OP_DETERMINISTIC, "ge.deterministic"},
                                                      {ACL_CUSTOMIZE_DTYPES, ge::CUSTOMIZE_DTYPES},
                                                      {ACL_OP_PRECISION_MODE, "ge.exec.op_precision_mode"},
                                                      {ACL_ALLOW_HF32, "ge.exec.allow_hf32"},
                                                      {ACL_OP_DEBUG_OPTION, "op_debug_option"}};

/**
 * Validates the arguments shared by the compile-and-execute entry points.
 *
 * The checks run in a fixed order: non-negative tensor counts, supported compile type, non-null
 * opType, the two tensor-description arrays, then the two data-buffer arrays. The first failing
 * check decides the returned error.
 *
 * @param opType      operator type name; must not be null
 * @param numInputs   number of entries in inputDesc / inputs
 * @param inputDesc   input tensor descriptions
 * @param inputs      input data buffers
 * @param numOutputs  number of entries in outputDesc / outputs
 * @param outputDesc  output tensor descriptions
 * @param outputs     output data buffers
 * @param compileFlag requested compile type
 * @return ACL_SUCCESS when every check passes; ACL_ERROR_API_NOT_SUPPORT when compileFlag is neither
 *         ACL_COMPILE_SYS nor ACL_COMPILE_UNREGISTERED. Other failures are produced by the
 *         ACL_REQUIRES_* macros (defined elsewhere).
 */
aclError CheckInput(const char *opType, const int32_t numInputs, const aclTensorDesc *const inputDesc[],
                    const aclDataBuffer *const inputs[], const int32_t numOutputs,
                    const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[],
                    const aclopCompileType compileFlag)
{
    ACL_REQUIRES_NON_NEGATIVE(numInputs);
    ACL_REQUIRES_NON_NEGATIVE(numOutputs);

    const bool isSupportedType = (compileFlag == ACL_COMPILE_SYS) || (compileFlag == ACL_COMPILE_UNREGISTERED);
    if (!isSupportedType) {
        ACL_LOG_ERROR("[Check][Type]aclopCompile compile type[%d] not support", static_cast<int32_t>(compileFlag));
        acl::AclErrorLogManager::ReportInputError(
            acl::UNSUPPORTED_FEATURE_MSG, std::vector<std::string>({"feature", "reason"}),
            std::vector<std::string>({"compile type", "must be equal to ACL_COMPILE_SYS or ACL_COMPILE_UNREGISTERED"}));
        return ACL_ERROR_API_NOT_SUPPORT;
    }

    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(opType);

    // Tensor descriptions first, then the data buffers, inputs before outputs in each group.
    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numInputs, inputDesc));
    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numOutputs, outputDesc));
    ACL_REQUIRES_OK(acl::array_utils::CheckDataBufferArry(numInputs, inputs));
    ACL_REQUIRES_OK(acl::array_utils::CheckDataBufferArry(numOutputs, outputs));

    return ACL_SUCCESS;
}

/**
 * Fills an acl::AclOp from the raw API arguments.
 *
 * All fields are copied as given (pointers are stored, not deep-copied). opPath is only consulted
 * when compileFlag is ACL_COMPILE_UNREGISTERED, in which case it is mandatory.
 *
 * @param aclOp       object to populate
 * @param opType      operator type name (callers validate it is non-null beforehand)
 * @param compileFlag compile type; stored as acl::OpCompileType
 * @param opPath      operator path, required only for ACL_COMPILE_UNREGISTERED
 * @param executeType execute type stored into the op
 * @param isCompile   value stored into aclOp.isCompile
 * @return ACL_SUCCESS, or ACL_ERROR_INVALID_PARAM when opPath is null for an unregistered compile
 */
aclError ConstructAclOp(acl::AclOp &aclOp, const char *opType, const int32_t numInputs,
                        const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[],
                        const int32_t numOutputs, const aclTensorDesc *const outputDesc[],
                        aclDataBuffer *const outputs[], const aclopAttr *attr, const aclopEngineType engineType,
                        const aclopCompileType compileFlag, const char *opPath, const acl::OpExecuteType executeType,
                        const bool isCompile)
{
    // Operator identity and behaviour switches.
    aclOp.opType = std::string(opType);
    aclOp.opAttr = attr;
    aclOp.engineType = engineType;
    aclOp.compileType = static_cast<acl::OpCompileType>(compileFlag);
    aclOp.exeucteType = executeType;
    aclOp.isCompile = isCompile;

    // Input side.
    aclOp.numInputs = numInputs;
    aclOp.inputDesc = inputDesc;
    aclOp.inputs = inputs;

    // Output side.
    aclOp.numOutputs = numOutputs;
    aclOp.outputDesc = outputDesc;
    aclOp.outputs = outputs;

    // Only unregistered operators carry an explicit path.
    if (compileFlag != ACL_COMPILE_UNREGISTERED) {
        return ACL_SUCCESS;
    }
    if (opPath != nullptr) {
        aclOp.opPath = std::string(opPath);
        return ACL_SUCCESS;
    }

    ACL_LOG_ERROR("[Check][OpPath]opPath cannot be null while compileFlag is %d",
                  static_cast<int32_t>(compileFlag));
    acl::AclErrorLogManager::ReportInputError(acl::INVALID_NULL_POINTER_MSG,
                                              std::vector<std::string>({"param"}),
                                              std::vector<std::string>({"opPath"}));
    return ACL_ERROR_INVALID_PARAM;
}

/**
 * Copies an option value into a caller-supplied character buffer and NUL-terminates it.
 *
 * @param value  destination buffer (callers in this file check it for null before calling)
 * @param length capacity of the destination buffer in bytes
 * @param str    text to copy
 * @return ACL_SUCCESS on success; ACL_ERROR_FAILURE when the buffer cannot hold the text plus the
 *         terminator, or when strncpy_s reports an error
 */
aclError CopyOptValue(char *value, size_t length, const std::string &str)
{
    const size_t srcLen = str.size();
    // One extra byte is needed for the terminating NUL.
    if (length < srcLen + 1U) {
        ACL_LOG_ERROR("[Check][PARAM] length[%zu] < str_size[%zu] + 1U", length, srcLen);
        return ACL_ERROR_FAILURE;
    }

    if (strncpy_s(value, length, str.c_str(), srcLen) != EOK) {
        ACL_LOG_INNER_ERROR("[Copy][Str]call strncpy_s failed, length: %zu, src size: %zu", length, srcLen);
        return ACL_ERROR_FAILURE;
    }

    value[srcLen] = '\0';
    return ACL_SUCCESS;
}
/**
 * Returns the default jit_compile setting for a SoC version string.
 *
 * @param version SoC version name, e.g. "Ascend910B1"
 * @return "disable" when the full name is one of the listed Ascend910B variants, or when its first
 *         12 characters are "Ascend910_93" or "Ascend910_95"; "enable" otherwise
 */
std::string GetDefaultJitCompileValue(const std::string &version)
{
    // SoC names that must match in full.
    static const std::set<std::string> kExactNames = {"Ascend910B1", "Ascend910B2",   "Ascend910B3",
                                                      "Ascend910B4", "Ascend910B4-1", "Ascend910B2C"};
    // SoC families identified by the leading characters of the name.
    static const std::set<std::string> kNamePrefixes = {"Ascend910_93", "Ascend910_95"};
    constexpr size_t kPrefixLen = 12UL;

    if (kExactNames.count(version) != 0U) {
        return "disable";
    }
    if (kNamePrefixes.count(version.substr(0, kPrefixLen)) != 0U) {
        return "disable";
    }
    return "enable";
}
}  // namespace

/**
 * Compiles a single operator described by its type, tensor descriptions and attributes.
 *
 * Validation order: non-negative counts, supported compile type, non-null opType, descriptor
 * arrays, then rejection of host-memory tensor descriptions. After validation an acl::AclOp is
 * filled in and handed to OpCompileProcessor::OpCompile.
 *
 * @param opType      operator type name; must not be null
 * @param numInputs   number of input tensor descriptions
 * @param inputDesc   input tensor descriptions
 * @param numOutputs  number of output tensor descriptions
 * @param outputDesc  output tensor descriptions
 * @param attr        operator attributes (stored as given)
 * @param engineType  engine type (stored as given)
 * @param compileFlag ACL_COMPILE_SYS or ACL_COMPILE_UNREGISTERED
 * @param opPath      operator path; mandatory when compileFlag is ACL_COMPILE_UNREGISTERED
 * @return result of OpCompile on success of validation; ACL_ERROR_API_NOT_SUPPORT for an unsupported
 *         compile type or host-memory tensor descriptions; ACL_ERROR_INVALID_PARAM when opPath is
 *         required but null; other errors come from the ACL_REQUIRES_* macros (defined elsewhere)
 */
aclError aclopCompile(const char *opType, int numInputs, const aclTensorDesc *const inputDesc[], int numOutputs,
                      const aclTensorDesc *const outputDesc[], const aclopAttr *attr, aclopEngineType engineType,
                      aclopCompileType compileFlag, const char *opPath)
{
    ACL_PROFILING_REG(acl::AclProfType::AclopCompile);
    ACL_STAGES_REG(acl::ACL_STAGE_COMP, acl::ACL_STAGE_DEFAULT);
    ACL_REQUIRES_NON_NEGATIVE(numInputs);
    ACL_REQUIRES_NON_NEGATIVE(numOutputs);
    if (compileFlag != ACL_COMPILE_SYS && compileFlag != ACL_COMPILE_UNREGISTERED) {
        ACL_LOG_ERROR("[Check][CompileFlag]aclopCompileType [%d] not support", static_cast<int32_t>(compileFlag));
        // The enum is converted explicitly so the reported value matches the logged one.
        acl::AclErrorLogManager::ReportInputError(
            acl::INVALID_PARAM_MSG, std::vector<std::string>({"param", "value", "reason"}),
            std::vector<std::string>(
                {"compile type", std::to_string(static_cast<int32_t>(compileFlag)), "not in range"}));
        return ACL_ERROR_API_NOT_SUPPORT;
    }
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(opType);
    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numInputs, inputDesc));
    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numOutputs, outputDesc));
    // Host-memory placements are rejected for a compile-only call.
    if (acl::array_utils::IsHostMemTensorDesc(numInputs, inputDesc) != ACL_SUCCESS) {
        ACL_LOG_INNER_ERROR("[Check][TensorDesc]aclopCompile ACL_MEMTYPE_HOST or "
                            "ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT placeMent in inputDesc not support");
        return ACL_ERROR_API_NOT_SUPPORT;
    }
    if (acl::array_utils::IsHostMemTensorDesc(numOutputs, outputDesc) != ACL_SUCCESS) {
        ACL_LOG_INNER_ERROR("[Check][TensorDesc]aclopCompile ACL_MEMTYPE_HOST or "
                            "ACL_MEMTYPE_HOST_COMPILE_INDEPENDENT placeMent in outputDesc not support");
        return ACL_ERROR_API_NOT_SUPPORT;
    }

    ACL_LOG_INFO("start to execute aclopCompile. opType = %s, engineType = %d, compileFlag = %d", opType,
                 static_cast<int32_t>(engineType), static_cast<int32_t>(compileFlag));
    acl::AclOp aclOp;
    aclOp.opType = std::string(opType);
    aclOp.numInputs = numInputs;
    aclOp.inputDesc = inputDesc;
    aclOp.numOutputs = numOutputs;
    aclOp.outputDesc = outputDesc;
    aclOp.opAttr = attr;
    aclOp.isCompile = true;
    aclOp.engineType = engineType;
    aclOp.compileType = static_cast<acl::OpCompileType>(compileFlag);
    if (compileFlag == ACL_COMPILE_UNREGISTERED) {
        if (opPath == nullptr) {
            // Cast the enum before passing it through the printf-style varargs, as done elsewhere in this file.
            ACL_LOG_ERROR("[Check][CompileFlag]opPath cannot be null while compileFlag is %d",
                          static_cast<int32_t>(compileFlag));
            acl::AclErrorLogManager::ReportInputError(acl::INVALID_NULL_POINTER_MSG,
                std::vector<std::string>({"param"}), std::vector<std::string>({"opPath"}));
            return ACL_ERROR_INVALID_PARAM;
        }
        aclOp.opPath = std::string(opPath);
    }
    ACL_LOG_INFO("aclopCompile::aclOp = %s", aclOp.DebugString().c_str());
    return acl::OpCompileProcessor::GetInstance().OpCompile(aclOp);
}

/**
 * Compiles a single operator and, if compilation succeeds, launches it asynchronously on a stream.
 *
 * Arguments are validated by CheckInput. When IsAllTensorEmpty reports true for the output
 * descriptions the call returns ACL_SUCCESS without compiling or executing anything.
 *
 * @param opType      operator type name
 * @param numInputs   number of inputs
 * @param inputDesc   input tensor descriptions
 * @param inputs      input data buffers
 * @param numOutputs  number of outputs
 * @param outputDesc  output tensor descriptions
 * @param outputs     output data buffers
 * @param attr        operator attributes
 * @param engineType  engine type
 * @param compileFlag ACL_COMPILE_SYS or ACL_COMPILE_UNREGISTERED
 * @param opPath      operator path, required for ACL_COMPILE_UNREGISTERED
 * @param stream      stream passed to OpExecutor::ExecuteAsync
 * @return validation / construction error, the OpCompile error, or the result of ExecuteAsync
 */
aclError aclopCompileAndExecute(const char *opType, int numInputs, const aclTensorDesc *const inputDesc[],
                                const aclDataBuffer *const inputs[], int numOutputs,
                                const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[],
                                const aclopAttr *attr, aclopEngineType engineType, aclopCompileType compileFlag,
                                const char *opPath, aclrtStream stream)
{
    ACL_PROFILING_REG(acl::AclProfType::AclopCompileAndExecute);
    RT2_PROFILING_SCOPE(gert::profiling::kUnknownName, gert::profiling::kAclCompileAndExecute);
    ACL_REQUIRES_OK(CheckInput(opType, numInputs, inputDesc, inputs, numOutputs, outputDesc, outputs, compileFlag));

    // Nothing to compute when every output tensor is empty.
    if (acl::array_utils::IsAllTensorEmpty(numOutputs, outputDesc)) {
        ACL_LOG_INFO("all ouput tensor are empty");
        return ACL_SUCCESS;
    }

    acl::AclOp aclOp;
    ACL_REQUIRES_OK(ConstructAclOp(aclOp, opType, numInputs, inputDesc, inputs, numOutputs, outputDesc, outputs, attr,
                                   engineType, compileFlag, opPath, acl::OpExecuteType::ACL_OP_EXECUTE, true));

    ACL_LOG_INFO("start to execute aclopCompileAndExecute, opType = %s, engineType = %d, compileFlag = %d", opType,
                 static_cast<int32_t>(engineType), static_cast<int32_t>(compileFlag));
    ACL_LOG_INFO("aclopCompile::aclOp = %s", aclOp.DebugString().c_str());

    // Compile phase.
    const auto compileRet = acl::OpCompileProcessor::GetInstance().OpCompile(aclOp);
    RT2_PROFILING_SCOPE_ELEMENT(aclOp.opModel.profilingIndex);
    if (compileRet != ACL_SUCCESS) {
        ACL_LOG_INNER_ERROR("build op model failed, result = %d", compileRet);
        return compileRet;
    }

    // Execute phase: the same op object is reused with the compile marker cleared.
    ACL_LOG_INFO("ExecuteAsync::aclOp = %s", aclOp.DebugString().c_str());
    aclOp.isCompile = false;
    return acl::OpExecutor::ExecuteAsync(aclOp, inputs, outputs, stream);
}

/**
 * Variant of aclopCompileAndExecute that takes mutable descriptors and buffers and builds the op
 * with execute type ACL_OP_EXECUTE_REFRESH_OUTPUT_ORI_SHAPE.
 *
 * Arguments are validated by CheckInput. When IsAllTensorEmpty reports true for the output
 * descriptions the call returns ACL_SUCCESS without compiling or executing anything.
 *
 * @param opType      operator type name
 * @param numInputs   number of inputs
 * @param inputDesc   input tensor descriptions
 * @param inputs      input data buffers
 * @param numOutputs  number of outputs
 * @param outputDesc  output tensor descriptions
 * @param outputs     output data buffers
 * @param attr        operator attributes
 * @param engineType  engine type
 * @param compileFlag ACL_COMPILE_SYS or ACL_COMPILE_UNREGISTERED
 * @param opPath      operator path, required for ACL_COMPILE_UNREGISTERED
 * @param stream      stream passed to OpExecutor::ExecuteAsync
 * @return validation / construction error, the OpCompile error, or the result of ExecuteAsync
 */
aclError aclopCompileAndExecuteV2(const char *opType, int numInputs, aclTensorDesc *inputDesc[],
                                  aclDataBuffer *inputs[], int numOutputs, aclTensorDesc *outputDesc[],
                                  aclDataBuffer *outputs[], aclopAttr *attr, aclopEngineType engineType,
                                  aclopCompileType compileFlag, const char *opPath, aclrtStream stream)
{
    ACL_PROFILING_REG(acl::AclProfType::AclopCompileAndExecuteV2);
    RT2_PROFILING_SCOPE(gert::profiling::kUnknownName, gert::profiling::kAclCompileAndExecuteV2);
    ACL_REQUIRES_OK(CheckInput(opType, numInputs, inputDesc, inputs, numOutputs, outputDesc, outputs, compileFlag));

    // Nothing to compute when every output tensor is empty.
    if (acl::array_utils::IsAllTensorEmpty(numOutputs, outputDesc)) {
        ACL_LOG_INFO("all ouput tensor are empty");
        return ACL_SUCCESS;
    }

    acl::AclOp aclOp;
    ACL_REQUIRES_OK(ConstructAclOp(aclOp, opType, numInputs, inputDesc, inputs, numOutputs, outputDesc, outputs, attr,
                                   engineType, compileFlag, opPath,
                                   acl::OpExecuteType::ACL_OP_EXECUTE_REFRESH_OUTPUT_ORI_SHAPE, true));

    ACL_LOG_INFO("start to execute aclopCompileAndExecuteV2, opType = %s, engineType = %d, compileFlag = %d", opType,
                 static_cast<int32_t>(engineType), static_cast<int32_t>(compileFlag));
    ACL_LOG_INFO("aclopCompileV2::aclOp = %s", aclOp.DebugString().c_str());

    // Compile phase.
    const auto compileRet = acl::OpCompileProcessor::GetInstance().OpCompile(aclOp);
    RT2_PROFILING_SCOPE_ELEMENT(aclOp.opModel.profilingIndex);
    if (compileRet != ACL_SUCCESS) {
        ACL_LOG_INNER_ERROR("build op model failed, result = %d", compileRet);
        return compileRet;
    }

    // Execute phase: the same op object is reused with the compile marker cleared.
    ACL_LOG_INFO("ExecuteAsyncV2::aclOp = %s", aclOp.DebugString().c_str());
    aclOp.isCompile = false;
    return acl::OpExecutor::ExecuteAsync(aclOp, inputs, outputs, stream);
}

/**
 * Sets a compile option on the OpCompileProcessor singleton.
 *
 * The enum is translated to its option key through compileOptMap. ACL_OP_JIT_COMPILE gets extra
 * handling: the text "enable" becomes "1" and any other text becomes "0", the jit flag is pushed
 * to the processor, and "ge.shape_generalized" is set to the opposite value.
 *
 * @param opt   option to set
 * @param value option value as text; must not be null
 * @return ACL_SUCCESS on success; ACL_ERROR_FEATURE_UNSUPPORTED for ACL_AUTO_TUNE_MODE;
 *         ACL_ERROR_INTERNAL_ERROR when the option has no key in compileOptMap; other errors come
 *         from the ACL_REQUIRES_* macros (defined elsewhere)
 */
aclError aclSetCompileopt(aclCompileOpt opt, const char *value)
{
    ACL_LOG_INFO("start to execute aclSetCompileopt");
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(value);

    if (opt == ACL_AUTO_TUNE_MODE) {
        ACL_LOG_INNER_ERROR("The Auto Tune function has been discarded. Please use the AOE tool for tuning.");
        return ACL_ERROR_FEATURE_UNSUPPORTED;
    }

    // Resolve the option key; a missing entry and an empty key are both treated as unknown.
    const auto optIter = compileOptMap.find(opt);
    if ((optIter == compileOptMap.cend()) || optIter->second.empty()) {
        ACL_LOG_INNER_ERROR("[Check][Opt]Can not find any options[%d] valid in enum aclCompileOpt, "
                            "please check input option.",
                            opt);
        return ACL_ERROR_INTERNAL_ERROR;
    }
    std::string optKey = optIter->second;

    std::string optValue(value);
    if (opt == ACL_OP_JIT_COMPILE) {
        const bool jitEnabled = (optValue == "enable");
        optValue = jitEnabled ? "1" : "0";
        int32_t jitFlag = jitEnabled ? 1 : 0;
        acl::OpCompileProcessor::GetInstance().SetJitCompileFlag(jitFlag);

        // Shape generalization is switched to the inverse of the jit setting.
        std::string generalizeKey = "ge.shape_generalized";
        std::string generalizeValue = jitEnabled ? "0" : "1";
        ACL_REQUIRES_OK(acl::OpCompileProcessor::GetInstance().SetCompileOpt(generalizeKey, generalizeValue));
        ACL_LOG_INFO("Set compile option [%s] and value [%s]", generalizeKey.c_str(), generalizeValue.c_str());
    }

    ACL_REQUIRES_OK(acl::OpCompileProcessor::GetInstance().SetCompileOpt(optKey, optValue));
    ACL_LOG_INFO("Set compile option [%s] and value [%s]", optKey.c_str(), optValue.c_str());
    return ACL_SUCCESS;
}

/**
 * Returns the buffer size a caller should provide to aclGetCompileopt for the given option.
 *
 * Workaround note (translated from the original comment): the aclGetCompileopt interface was
 * added to expose the jit_compile value. The formal solution was estimated at roughly 0.5K of
 * work and could not be finished before the deadline, so this workaround was merged with the
 * agreement of the PL, PM, committer and SE. The formal solution should return the value that
 * corresponds to the configuration file. Constraints: none.
 *
 * @param opt option to query
 * @return COMPILE_OPT_SIZE for ACL_OP_JIT_COMPILE; for any other option, the length of the value
 *         obtained from OpCompileProcessor::GetCompileOpt plus one for the terminator, or 0 when
 *         the option has no key in compileOptMap or the obtained value is empty
 */
size_t aclGetCompileoptSize(aclCompileOpt opt)
{
    // ACL_OP_JIT_COMPILE always reports the fixed size.
    if (opt == ACL_OP_JIT_COMPILE) {
        return COMPILE_OPT_SIZE;
    }

    const auto optIter = compileOptMap.find(opt);
    if ((optIter == compileOptMap.cend()) || optIter->second.empty()) {
        return 0UL;
    }

    // Report the actual length of the value; the status of the query itself is deliberately ignored.
    const std::string optKey = optIter->second;
    std::string optValue;
    (void)acl::OpCompileProcessor::GetInstance().GetCompileOpt(optKey, optValue);
    if (optValue.empty()) {
        return 0UL;
    }
    return optValue.size() + 1UL;
}

/**
 * Copies the value of a compile option into a caller-supplied buffer.
 *
 * Workaround note (translated from the original comment): the aclGetCompileopt interface was
 * added to expose the jit_compile value. The formal solution was estimated at roughly 0.5K of
 * work and could not be finished before the deadline, so this workaround was merged with the
 * agreement of the PL, PM, committer and SE. For now the result is chosen from the SoC version,
 * and only ACL_OP_JIT_COMPILE and ACL_OP_DEBUG_OPTION can be queried. The formal solution should
 * return the value that corresponds to the configuration file. Constraints: none.
 *
 * @param opt    option to query
 * @param value  destination buffer; must not be null
 * @param length capacity of the destination buffer in bytes
 * @return result of CopyOptValue for the two supported options; ACL_ERROR_API_NOT_SUPPORT for any
 *         other option or when the ACL_OP_DEBUG_OPTION query fails
 */
aclError aclGetCompileopt(aclCompileOpt opt, char *value, size_t length)
{
    ACL_REQUIRES_NOT_NULL(value);

    // jit_compile: the answer is derived from the SoC version rather than read back from storage.
    if (opt == ACL_OP_JIT_COMPILE) {
        const auto &socVersion = acl::GetSocVersion();
        return CopyOptValue(value, length, GetDefaultJitCompileValue(socVersion));
    }

    if (opt != ACL_OP_DEBUG_OPTION) {
        return ACL_ERROR_API_NOT_SUPPORT;
    }

    // op_debug_option: read the value held by the compile processor.
    std::string storedValue;
    const aclError queryRet = acl::OpCompileProcessor::GetInstance().GetCompileOpt(compileOptMap[opt], storedValue);
    if (queryRet != ACL_SUCCESS) {
        return ACL_ERROR_API_NOT_SUPPORT;
    }
    return CopyOptValue(value, length, storedValue);
}

/**
 * Forwards the operator compile flag to the OpCompileProcessor singleton.
 *
 * @param flag compile flag; passed on as its int32_t value
 * @return ACL_SUCCESS when SetCompileFlag succeeds; a failure is handled by ACL_REQUIRES_OK
 *         (macro defined elsewhere)
 */
aclError aclopSetCompileFlag(aclOpCompileFlag flag)
{
    const int32_t flagValue = static_cast<int32_t>(flag);
    ACL_LOG_INFO("start to execute aclopSetCompileFlag, flag is %d", flagValue);
    ACL_REQUIRES_OK(acl::OpCompileProcessor::GetInstance().SetCompileFlag(flagValue));
    return ACL_SUCCESS;
}

/**
 * Builds the graph for a single operator and dumps it to graphDumpPath.
 *
 * graphDumpOpt is reserved: only nullptr is accepted at present. When IsAllTensorEmpty reports
 * true for the output descriptions the call returns ACL_SUCCESS without doing any work.
 *
 * @param opType        operator type name; must not be null
 * @param numInputs     number of inputs
 * @param inputDesc     input tensor descriptions
 * @param inputs        input data buffers
 * @param numOutputs    number of outputs
 * @param outputDesc    output tensor descriptions
 * @param outputs       output data buffers
 * @param attr          operator attributes
 * @param engineType    engine type
 * @param graphDumpPath destination path for the dump; must not be null
 * @param graphDumpOpt  dump options; must be nullptr
 * @return result of OpCompileAndDump after successful validation; ACL_ERROR_INVALID_PARAM when
 *         graphDumpOpt is not nullptr; other errors come from the ACL_REQUIRES_* macros (defined elsewhere)
 */
aclError aclGenGraphAndDumpForOp(const char *opType, int numInputs, const aclTensorDesc *const inputDesc[],
                                 const aclDataBuffer *const inputs[], int numOutputs,
                                 const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[],
                                 const aclopAttr *attr, aclopEngineType engineType, const char *graphDumpPath,
                                 const aclGraphDumpOption *graphDumpOpt)
{
    ACL_PROFILING_REG(acl::AclProfType::AclGenGraphAndDumpForOp);
    ACL_STAGES_REG(acl::ACL_STAGE_COMP_AND_EXEC, acl::ACL_STAGE_DEFAULT);
    ACL_LOG_INFO("start to execute aclGenGraphAndDumpForOp");
    if (graphDumpOpt != nullptr) {
        ACL_LOG_ERROR("[Check][PARAM]graphDumpOpt only support nullptr currently");
        const std::string errMsg = acl::AclErrorLogManager::FormatStr("only support nullptr currently");
        const char_t *argList[] = {"param", "value", "reason"};
        // Report the parameter that is actually being rejected (previously a copy-pasted, unrelated name).
        const char_t *argVal[] = {"graphDumpOpt", "not nullptr", errMsg.c_str()};
        acl::AclErrorLogManager::ReportInputErrorWithChar(acl::INVALID_PARAM_MSG, argList, argVal, 3U);
        return ACL_ERROR_INVALID_PARAM;
    }
    ACL_REQUIRES_NON_NEGATIVE(numInputs);
    ACL_REQUIRES_NON_NEGATIVE(numOutputs);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(opType);
    ACL_REQUIRES_NOT_NULL_WITH_INPUT_REPORT(graphDumpPath);
    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numInputs, inputDesc));
    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numOutputs, outputDesc));
    // Nothing to build when every output tensor is empty.
    if (acl::array_utils::IsAllTensorEmpty(numOutputs, outputDesc)) {
        ACL_LOG_INFO("all ouput tensor are empty");
        return ACL_SUCCESS;
    }

    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numInputs, inputs));
    ACL_REQUIRES_OK(acl::array_utils::CheckPtrArray(numOutputs, outputs));

    acl::AclOp aclOp;
    aclOp.opType.assign(opType);
    aclOp.numInputs = numInputs;
    aclOp.inputDesc = inputDesc;
    aclOp.numOutputs = numOutputs;
    aclOp.outputDesc = outputDesc;
    aclOp.inputs = inputs;
    aclOp.outputs = outputs;
    aclOp.opAttr = attr;
    aclOp.isCompile = true;
    aclOp.engineType = engineType;

    ACL_LOG_INFO("aclopCompile::aclOp = %s", aclOp.DebugString().c_str());
    return acl::OpCompileProcessor::GetInstance().OpCompileAndDump(aclOp, graphDumpPath, graphDumpOpt);
}

/**
 * Allocates a default-constructed aclGraphDumpOption.
 *
 * The apply-total statistic is always incremented; the apply-success statistic is incremented only
 * when the allocation actually succeeded, so a failed allocation is no longer counted as a success.
 *
 * @return the new object, or nullptr when the allocation fails. The object is released with
 *         aclDestroyGraphDumpOpt.
 */
aclGraphDumpOption *aclCreateGraphDumpOpt()
{
    ACL_ADD_APPLY_TOTAL_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_GRAPH_DUMP_OPTION);
    aclGraphDumpOption *const graphDumpOpt = new (std::nothrow) aclGraphDumpOption();
    if (graphDumpOpt == nullptr) {
        return nullptr;
    }
    ACL_ADD_APPLY_SUCCESS_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_GRAPH_DUMP_OPTION);
    return graphDumpOpt;
}

/**
 * Releases an aclGraphDumpOption with delete.
 *
 * The release-total statistic is incremented on every call; the release-success statistic only
 * after the object has been deleted.
 *
 * @param graphDumpOpt object to destroy
 * @return ACL_SUCCESS after deletion; ACL_ERROR_INVALID_PARAM when graphDumpOpt is nullptr
 */
aclError aclDestroyGraphDumpOpt(const aclGraphDumpOption *graphDumpOpt)
{
    ACL_ADD_RELEASE_TOTAL_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_GRAPH_DUMP_OPTION);
    if (graphDumpOpt != nullptr) {
        delete graphDumpOpt;
        ACL_ADD_RELEASE_SUCCESS_COUNT(acl::ACL_STATISTICS_CREATE_DESTROY_GRAPH_DUMP_OPTION);
        return ACL_SUCCESS;
    }
    return ACL_ERROR_INVALID_PARAM;
}
